{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H8kUCL8VpWJN"
      },
      "source": [
        "<span style=\"color: red;\">Requirement when running in Goolge Colab</span>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OILt1P6RpWJO"
      },
      "outputs": [],
      "source": [
        "!pip install diffusers"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d6xsY5hUpWJP"
      },
      "source": [
        "#  Chapter 5 - Replace Attention Layers\n",
        "\n",
        "So far we learnt how to generate images and how to break the model down all the way to its attention layer. As mentioned in Stable Diffusion in its U-Net architecture has self-attention and cross-attention and a resent on top of that to decrease or increase dimensionality based of whether its in encoder or decoder stage, so we are dealing with different dimensions of attention maps as well.\n",
        "\n",
        "The idea in this chapter is to generate a synthetic image with a prompt and an editted version of that prompt and during generation to replace some dimension of the cross-attention maps and see the relevant results and see if we can achieve similiar images with the required editting"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7KF8rSaapWJQ"
      },
      "source": [
        "This part, as before, it's been copied from the previous chapter and you can and run and move to the next cell"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gtn4TVEFpWJR"
      },
      "outputs": [],
      "source": [
        "import warnings\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "from diffusers import StableDiffusionPipeline\n",
        "import torch\n",
        "import matplotlib.pyplot as plt\n",
        "from typing import Optional\n",
        "from tqdm import tqdm\n",
        "\n",
        "\n",
        "model_id = \"stabilityai/stable-diffusion-2-1-base\"\n",
        "\n",
        "pipe = StableDiffusionPipeline.from_pretrained(model_id)\n",
        "pipe = pipe.to(\"cuda\")\n",
        "\n",
        "prompt = \"A photo of a woman, straight hair, light blonde and pink hair, smiling expression, grey background\"\n",
        "\n",
        "prompt_embeds = pipe.encode_prompt(prompt=prompt, device=\"cuda\", num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=\"\")\n",
        "\n",
        "initial_latents = torch.randn((1, pipe.unet.in_channels, pipe.unet.config.sample_size, pipe.unet.config.sample_size), generator=torch.Generator().manual_seed(220)).to(\"cuda\")"
      ]
    },
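    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check (not part of the original flow), we can inspect the initial latents; for stable-diffusion-2-1-base the U-Net config gives `in_channels = 4` and `sample_size = 64`, so the expected shape is `(1, 4, 64, 64)`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Sanity check: the latent tensor should match the U-Net's configured\n",
        "# input channels and spatial sample size (1 x 4 x 64 x 64 for SD 2.1-base).\n",
        "print(initial_latents.shape)\n",
        "assert initial_latents.shape == (1, pipe.unet.config.in_channels, pipe.unet.config.sample_size, pipe.unet.config.sample_size)"
      ]
    },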
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9yn8T5LapWJR"
      },
      "outputs": [],
      "source": [
        "new_prompt = \"A photo of a woman, curly hair, light blonde and pink hair, smiling expression, grey background\"\n",
        "\n",
        "new_prompt_embeds = pipe.encode_prompt(prompt=new_prompt, device=\"cuda\", num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=\"\")\n",
        "prompt_embeds_combined = torch.cat([prompt_embeds[1], prompt_embeds[0], new_prompt_embeds[1], new_prompt_embeds[0]])"
      ]
    },
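    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Another quick check (an assumption worth verifying, not part of the original flow): `encode_prompt` returns `(prompt_embeds, negative_prompt_embeds)`, so the concatenation above stacks the batches as `[uncond, cond, new_uncond, new_cond]`, matching the `chunk(4)` call used during sampling below. For SD 2.1-base the expected shape is `(4, 77, 1024)`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# The combined embeddings should hold four batches, one per chunk consumed\n",
        "# during sampling: [uncond, cond, new_uncond, new_cond].\n",
        "print(prompt_embeds_combined.shape)  # expected: torch.Size([4, 77, 1024]) for SD 2.1-base\n",
        "assert prompt_embeds_combined.shape[0] == 4"
      ]
    },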
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Bt6ozSnnpWJR"
      },
      "source": [
        "A Modification of the contextual_forward from the previous section with the addition of replacing the target self-attention maps with the ones of the source and we used a specific configuration that granted the best results for our task at hand"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UfGUd6RGpWJR"
      },
      "outputs": [],
      "source": [
        "\n",
        "def contextual_forward(self, source_batch_index, target_batch_index, skip_dimension, should_replace = False):\n",
        "\n",
        "    def forward_modified(\n",
        "        hidden_states: torch.FloatTensor,\n",
        "        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n",
        "        attention_mask: Optional[torch.FloatTensor] = None,\n",
        "        temb: Optional[torch.FloatTensor] = None,\n",
        "    ) -> torch.FloatTensor:\n",
        "\n",
        "            residual = hidden_states\n",
        "\n",
        "            is_cross = not encoder_hidden_states is None\n",
        "\n",
        "            input_ndim = hidden_states.ndim\n",
        "\n",
        "            if input_ndim == 4:\n",
        "                batch_size, channel, height, width = hidden_states.shape\n",
        "                hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n",
        "\n",
        "            batch_size, _, _ = (\n",
        "                hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n",
        "            )\n",
        "\n",
        "            if self.group_norm is not None:\n",
        "                hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n",
        "\n",
        "            query = self.to_q(hidden_states)\n",
        "\n",
        "            if encoder_hidden_states is None:\n",
        "                encoder_hidden_states = hidden_states\n",
        "            elif self.norm_cross:\n",
        "                encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)\n",
        "\n",
        "            key = self.to_k(encoder_hidden_states)\n",
        "            value = self.to_v(encoder_hidden_states)\n",
        "\n",
        "            query = self.head_to_batch_dim(query)\n",
        "            key = self.head_to_batch_dim(key)\n",
        "            value = self.head_to_batch_dim(value)\n",
        "\n",
        "            attention_scores = self.scale * torch.bmm(query, key.transpose(-1, -2))\n",
        "\n",
        "            #############################################################\n",
        "            ### The replacing process of attention maps happens here    ###\n",
        "            #############################################################\n",
        "\n",
        "            # each attention_scores comes with a batch shape of heads * batch size and considering\n",
        "            # the unconditional prompt that we require for CFG technically if couple them up as one\n",
        "            # batch then technically it comes in the format of heads * batch size * 2 and we use that\n",
        "            # information to find the relevant indicies to each batch\n",
        "            num_heads = self.heads\n",
        "            source_starting_index = num_heads * source_batch_index * 2\n",
        "            source_ending_index = num_heads * (source_batch_index + 1) * 2\n",
        "            target_starting_index = num_heads * target_batch_index * 2\n",
        "            target_ending_index = num_heads * (target_batch_index + 1) * 2\n",
        "\n",
        "            dimension_squared = hidden_states.shape[1]\n",
        "\n",
        "            # our experiement showed that this is the combination granted the best results when it comes\n",
        "            # to facial related changes, hence this specific configuration for our replacement\n",
        "            if not is_cross and (should_replace or not dimension_squared == skip_dimension * skip_dimension):\n",
        "                attention_scores[target_starting_index:target_ending_index,:,:] = attention_scores[source_starting_index:source_ending_index,:,:]\n",
        "            #############################################################\n",
        "            attention_probs = attention_scores.softmax(dim=-1)\n",
        "            del attention_scores\n",
        "\n",
        "            hidden_states = torch.bmm(attention_probs, value)\n",
        "            hidden_states = self.batch_to_head_dim(hidden_states)\n",
        "            del attention_probs\n",
        "\n",
        "            hidden_states = self.to_out[0](hidden_states)\n",
        "\n",
        "            if input_ndim == 4:\n",
        "                hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n",
        "\n",
        "            if self.residual_connection:\n",
        "                hidden_states = hidden_states + residual\n",
        "\n",
        "            hidden_states = hidden_states / self.rescale_output_factor\n",
        "\n",
        "            return hidden_states\n",
        "\n",
        "    return forward_modified"
      ]
    },
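    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the slice arithmetic above concrete, here is a small standalone sketch (illustrative only; the head count of 5 is just an example, the real value varies per attention layer). Each batch owns `num_heads * 2` consecutive rows, its unconditional and conditional halves:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative only: reproduce the slice arithmetic from contextual_forward.\n",
        "num_heads = 5  # example value; the real head count varies per layer\n",
        "for batch_index in (0, 1):\n",
        "    start = num_heads * batch_index * 2\n",
        "    end = num_heads * (batch_index + 1) * 2\n",
        "    print(f\"batch {batch_index}: rows {start}..{end - 1}\")\n",
        "# batch 0 (source) -> rows 0..9, batch 1 (target) -> rows 10..19"
      ]
    },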
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J0Qz5JTfpWJS"
      },
      "source": [
        "To differniate between encoding state (Down) and decoding state (Up) of the U-Net attention checks needs to be added by using the class name:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "G12kfcS2pWJS"
      },
      "outputs": [],
      "source": [
        "# source_batch_index and target_batch_index added to copy the relevant attention maps\n",
        "def apply_forward_function(unet, child = None, source_batch_index=-1, target_batch_index=-1, should_replace= False):\n",
        "    if child == None:\n",
        "        children = unet.named_children()\n",
        "        for child in children:\n",
        "            apply_forward_function(unet, child[1], source_batch_index, target_batch_index, should_replace)\n",
        "    else:\n",
        "        if child.__class__.__name__ == 'Attention':\n",
        "            child.forward = contextual_forward(child, source_batch_index, target_batch_index, pipe.unet.config.sample_size, should_replace)\n",
        "        elif hasattr(child, 'children'):\n",
        "            for sub_child in child.children():\n",
        "                block_name = child.__class__.__name__\n",
        "\n",
        "                if \"Down\" in block_name:\n",
        "                    should_replace = True\n",
        "                elif \"Up\" in block_name:\n",
        "                    should_replace = False\n",
        "                elif \"Mid\" in block_name:\n",
        "                    should_replace = True\n",
        "                apply_forward_function(unet, sub_child, source_batch_index, target_batch_index, should_replace)"
      ]
    },
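    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before patching, it can help to see how many `Attention` modules the traversal will reach. A minimal sketch, matching on the same class name the function above uses:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Count the Attention modules inside the U-Net; apply_forward_function\n",
        "# will replace the forward of every one of them.\n",
        "attention_count = sum(\n",
        "    1 for module in pipe.unet.modules()\n",
        "    if module.__class__.__name__ == \"Attention\"\n",
        ")\n",
        "print(f\"{attention_count} Attention modules found in the U-Net\")"
      ]
    },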
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f78f_Yf2pWJS"
      },
      "source": [
        "A wrapper for displaying the latents, considering we are dealing with 2 images this time"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gvsrry29pWJT"
      },
      "outputs": [],
      "source": [
        "def display_latents(latents):\n",
        "    with torch.no_grad():\n",
        "        image_0 = pipe.vae.decode(latents[0].unsqueeze(0) / pipe.vae.config.scaling_factor, return_dict=False)[0]\n",
        "        image_np_0 = image_0.squeeze(0).float().permute(1, 2, 0).detach().cpu().numpy()\n",
        "        image_np_0 = (image_np_0 - image_np_0.min()) / (image_np_0.max() - image_np_0.min())\n",
        "\n",
        "        image_1 = pipe.vae.decode(latents[1].unsqueeze(0) / pipe.vae.config.scaling_factor, return_dict=False)[0]\n",
        "        image_np_1 = image_1.squeeze(0).float().permute(1, 2, 0).detach().cpu().numpy()\n",
        "        image_np_1 = (image_np_1 - image_np_1.min()) / (image_np_1.max() - image_np_1.min())\n",
        "\n",
        "        fig, axes = plt.subplots(1, 2, figsize=(12, 6))\n",
        "\n",
        "        axes[0].imshow(image_np_0)\n",
        "        axes[0].axis('off')\n",
        "        axes[0].set_title('Latent 0')\n",
        "\n",
        "        axes[1].imshow(image_np_1)\n",
        "        axes[1].axis('off')\n",
        "        axes[1].set_title('Latent 1')\n",
        "\n",
        "        plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_j5GbxqMpWJT"
      },
      "source": [
        "# Displaying the sampling results WITHOUT replacing the attention layer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Dv8BMAmgpWJU"
      },
      "outputs": [],
      "source": [
        "num_inference_steps = 50\n",
        "guidance_scale = 7.5\n",
        "pipe.scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n",
        "timesteps = pipe.scheduler.timesteps\n",
        "\n",
        "latents = torch.cat([initial_latents] * 2)\n",
        "\n",
        "with torch.no_grad():\n",
        "\n",
        "    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc=\"Inference steps\"):\n",
        "\n",
        "        latent_model_input = torch.cat([latents[0].unsqueeze(0),latents[0].unsqueeze(0),latents[1].unsqueeze(0),latents[1].unsqueeze(0)])\n",
        "\n",
        "        noise_pred = pipe.unet(\n",
        "            latent_model_input,\n",
        "            t,\n",
        "            encoder_hidden_states=prompt_embeds_combined,\n",
        "            cross_attention_kwargs=None,\n",
        "            return_dict=False,\n",
        "        )[0]\n",
        "\n",
        "\n",
        "        noise_pred_uncond, noise_pred_text, new_noise_pred_uncond, new_noise_pred_text = noise_pred.chunk(4)\n",
        "\n",
        "        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n",
        "        new_noise_pred = new_noise_pred_uncond + guidance_scale * (new_noise_pred_text - new_noise_pred_uncond)\n",
        "\n",
        "        latents = pipe.scheduler.step(torch.cat([noise_pred, new_noise_pred]), t, latents, return_dict=False)[0]\n",
        "\n",
        "    display_latents(latents)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ltILS8wRpWJU"
      },
      "source": [
        "# Displaying the sampling results WITH replacing the attention layer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2O4Pl6zJpWJU"
      },
      "outputs": [],
      "source": [
        "num_inference_steps = 50\n",
        "guidance_scale = 7.5\n",
        "pipe.scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n",
        "timesteps = pipe.scheduler.timesteps\n",
        "\n",
        "latents = torch.cat([initial_latents] * 2)\n",
        "\n",
        "with torch.no_grad():\n",
        "\n",
        "    # we want to replace the self-attention maps of the batch with the new prompt with the ones from the old prompt\n",
        "    # therefore copying 1 to 0\n",
        "    apply_forward_function(pipe.unet, source_batch_index=0, target_batch_index=1)\n",
        "\n",
        "    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc=\"Inference steps\"):\n",
        "\n",
        "        latent_model_input = torch.cat([latents[0].unsqueeze(0),latents[0].unsqueeze(0),latents[1].unsqueeze(0),latents[1].unsqueeze(0)])\n",
        "\n",
        "        noise_pred = pipe.unet(\n",
        "            latent_model_input,\n",
        "            t,\n",
        "            encoder_hidden_states=prompt_embeds_combined,\n",
        "            cross_attention_kwargs=None,\n",
        "            return_dict=False,\n",
        "        )[0]\n",
        "\n",
        "\n",
        "        noise_pred_uncond, noise_pred_text, new_noise_pred_uncond, new_noise_pred_text = noise_pred.chunk(4)\n",
        "\n",
        "        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n",
        "        new_noise_pred = new_noise_pred_uncond + guidance_scale * (new_noise_pred_text - new_noise_pred_uncond)\n",
        "\n",
        "        latents = pipe.scheduler.step(torch.cat([noise_pred, new_noise_pred]), t, latents, return_dict=False)[0]\n",
        "\n",
        "    display_latents(latents)\n",
        "\n"
      ]
    },
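    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Note that `apply_forward_function` assigns the patched forward as an instance attribute, so it stays in place after this run. A minimal cleanup sketch (not part of the chapter's flow): deleting the instance attribute lets Python fall back to each module's original class-level `forward`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical cleanup helper: remove the instance-level forward we assigned\n",
        "# so each Attention module uses its class's original forward again.\n",
        "def restore_attention_forwards(unet):\n",
        "    for module in unet.modules():\n",
        "        if module.__class__.__name__ == \"Attention\" and \"forward\" in vars(module):\n",
        "            del module.forward  # the class-level forward takes over again\n",
        "\n",
        "restore_attention_forwards(pipe.unet)"
      ]
    },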
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kulqLT_TpWJU"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "L4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.11"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
